package pers.hll.aigc4chat.server.controller.model;

import io.swagger.v3.oas.annotations.Operation;
import io.swagger.v3.oas.annotations.Parameter;
import io.swagger.v3.oas.annotations.tags.Tag;
import lombok.RequiredArgsConstructor;
import org.springframework.validation.annotation.Validated;
import org.springframework.web.bind.annotation.*;
import pers.hll.aigc4chat.base.constant.AuthType;
import pers.hll.aigc4chat.base.exception.Assert;
import pers.hll.aigc4chat.base.util.XmlUtil;
import pers.hll.aigc4chat.base.xml.BaiduConfig;
import pers.hll.aigc4chat.server.base.R;
import pers.hll.aigc4chat.server.bean.command.BaiduConfigCommand;
import pers.hll.aigc4chat.server.converter.BaiduConfigConverter;
import pers.lys.aigc4chat.model.baidu.BaiduApi;
import pers.lys.aigc4chat.model.baidu.constant.ModelName;
import pers.lys.aigc4chat.model.baidu.constant.Role;
import pers.lys.aigc4chat.model.baidu.request.body.ChatReqBody;
import pers.lys.aigc4chat.model.baidu.request.body.Message;
import pers.lys.aigc4chat.model.baidu.response.body.ChatRespBody;

import java.io.IOException;
import java.util.Collections;

/**
 * REST controller for the Baidu model endpoints, mounted under {@code /baidu}.
 *
 * <p>Exposes chat, single-turn chat, access-token refresh, and reading/editing of the
 * authentication configuration. Every handler wraps its outcome in an {@link R} envelope.
 *
 * @author hll
 * @since 2024/05/07
 */
@Tag(name = "BaiduController", description = "百度控制器")
@RequestMapping("/baidu")
@RequiredArgsConstructor
@RestController
public class BaiduController {

    /**
     * Forwards a caller-supplied chat request to {@link BaiduApi#chat}.
     *
     * @param modelName   model to use; falls back to {@link ModelName#EB_INSTANT} when omitted
     * @param chatReqBody chat request body taken from the HTTP request body
     * @return the chat response wrapped in {@link R}
     */
    @Operation(summary = "对话")
    @PostMapping("/chat")
    public R<ChatRespBody> chat(
            @RequestParam(required = false, defaultValue = ModelName.EB_INSTANT) String modelName,
            @RequestBody ChatReqBody chatReqBody) {
        return ask(chatReqBody, modelName);
    }

    /**
     * Single-turn chat: sends one {@link Role#USER} message built from {@code content}
     * to the {@link ModelName#EB_INSTANT} model.
     *
     * @param content text of the user message
     * @return the chat response wrapped in {@link R}
     */
    @Operation(summary = "单轮对话")
    @PostMapping("/easy-chat")
    public R<ChatRespBody> easyChat(@Parameter(description = "内容") @RequestParam String content) {
        ChatReqBody singleTurn = new ChatReqBody();
        singleTurn.setMessages(Collections.singletonList(new Message(Role.USER, content)));
        return ask(singleTurn, ModelName.EB_INSTANT);
    }

    /**
     * Validates the submitted authentication configuration, converts it with
     * {@link BaiduConfigConverter#from} and writes it through {@link XmlUtil#writeXmlConfig}.
     *
     * @param baiduConfigCommand authentication configuration to store
     * @return a success result, or a failure result carrying the {@link IOException} message
     *         when writing fails
     */
    @Operation(summary = "编辑鉴权配置")
    @PostMapping("/config")
    public R<String> editConfig(@Validated @RequestBody BaiduConfigCommand baiduConfigCommand) {
        // Auth-type specific checks run before anything is written.
        checkBaiduConfigCommand(baiduConfigCommand);
        try {
            XmlUtil.writeXmlConfig(BaiduConfigConverter.from(baiduConfigCommand));
            return R.success();
        } catch (IOException writeFailure) {
            return R.fail(writeFailure.getMessage());
        }
    }

    /**
     * Triggers {@link BaiduApi#refreshAccessToken()}.
     *
     * @return a success result once the refresh call returns
     */
    @Operation(summary = "更新access-token", description = "只有在 authType = ACCESS_TOKEN 时可用")
    @PostMapping("/refresh-access-token")
    public R<String> refreshAccessToken() {
        BaiduApi.refreshAccessToken();
        return R.success();
    }

    /**
     * Reads the authentication configuration via {@link XmlUtil#readXmlConfig}.
     *
     * @return the loaded {@link BaiduConfig} wrapped in {@link R}
     */
    @Operation(summary = "获取鉴权配置")
    @GetMapping("/config")
    public R<BaiduConfig> config() {
        BaiduConfig stored = XmlUtil.readXmlConfig(BaiduConfig.class);
        return R.data(stored);
    }

    /**
     * Checks that the credential pair matching the selected auth type is present.
     *
     * <p>{@code ACCESS_TOKEN} needs apiKey and secretKey; {@code AK_SK} needs accessKeyId and
     * secretAccessKey. Any other auth type is accepted without further checks here.
     *
     * @param command authentication configuration to check
     */
    private void checkBaiduConfigCommand(BaiduConfigCommand command) {
        if (command.getAuthType() == AuthType.ACCESS_TOKEN) {
            Assert.notEmpty(command.getApiKey(), "apiKey不能为空");
            Assert.notEmpty(command.getSecretKey(), "secretKey不能为空");
        } else if (command.getAuthType() == AuthType.AK_SK) {
            Assert.notEmpty(command.getAccessKeyId(), "accessKeyId不能为空");
            Assert.notEmpty(command.getSecretAccessKey(), "secretAccessKey不能为空");
        }
    }

    /**
     * Calls {@link BaiduApi#chat} and wraps the answer in {@link R}; shared by both chat endpoints.
     *
     * @param body  chat request body
     * @param model name of the model to query
     * @return the chat response wrapped in {@link R}
     */
    private R<ChatRespBody> ask(ChatReqBody body, String model) {
        return R.data(BaiduApi.chat(body, model));
    }
}
